import copy
import os
import pickle

import numpy as np
import SharedArray
import torch
import torch.distributed as dist
from pcdet.datasets.kitti.kitti_object_eval_python import kitti_common
from skimage import io

from ...ops.iou3d_nms import iou3d_nms_utils
from ...utils import box_utils, calibration_kitti, common_utils


class DataBaseSampler(object):
    def __init__(self, root_path, sampler_cfg, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.sampler_cfg = sampler_cfg

        self.img_aug_type = sampler_cfg.get("IMG_AUG_TYPE", None)
        self.img_aug_iou_thresh = sampler_cfg.get("IMG_AUG_IOU_THRESH", 0.5)

        self.logger = logger
        self.db_infos = {}
        for class_name in class_names:
            self.db_infos[class_name] = []

        self.use_shared_memory = sampler_cfg.get("USE_SHARED_MEMORY", False)

        for db_info_path in sampler_cfg.DB_INFO_PATH:
            db_info_path = self.root_path.resolve() / db_info_path
            if not db_info_path.exists():
                assert len(sampler_cfg.DB_INFO_PATH) == 1
                sampler_cfg.DB_INFO_PATH[0] = sampler_cfg.BACKUP_DB_INFO["DB_INFO_PATH"]
                sampler_cfg.DB_DATA_PATH[0] = sampler_cfg.BACKUP_DB_INFO["DB_DATA_PATH"]
                db_info_path = self.root_path.resolve() / sampler_cfg.DB_INFO_PATH[0]
                sampler_cfg.NUM_POINT_FEATURES = sampler_cfg.BACKUP_DB_INFO["NUM_POINT_FEATURES"]

            with open(str(db_info_path), "rb") as f:
                infos = pickle.load(f)
                [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]

        for func_name, val in sampler_cfg.PREPARE.items():
            self.db_infos = getattr(self, func_name)(self.db_infos, val)

        self.gt_database_data_key = (
            self.load_db_to_shared_memory() if self.use_shared_memory else None
        )

        self.sample_groups = {}
        self.sample_class_num = {}
        self.limit_whole_scene = sampler_cfg.get("LIMIT_WHOLE_SCENE", False)

        for x in sampler_cfg.SAMPLE_GROUPS:
            class_name, sample_num = x.split(":")
            if class_name not in class_names:
                continue
            self.sample_class_num[class_name] = sample_num
            self.sample_groups[class_name] = {
                "sample_num": sample_num,
                "pointer": len(self.db_infos[class_name]),
                "indices": np.arange(len(self.db_infos[class_name])),
            }

    def __getstate__(self):
        """Return a copy of the instance state for pickling, without ``logger``."""
        state = self.__dict__.copy()
        state.pop("logger")
        return state

    def __setstate__(self, d):
        """Restore the instance state produced by ``__getstate__``.

        ``__getstate__`` removes ``logger`` from the state, so a restored
        instance would otherwise have no ``logger`` attribute at all and any
        ``self.logger`` access would raise AttributeError. The attribute is
        therefore restored as ``None`` unless the state provides one.

        Args:
            d: attribute dictionary to merge into the instance.
        """
        self.__dict__.update(d)
        self.__dict__.setdefault("logger", None)

    def __del__(self):
        """Remove the GT database from shared memory when the sampler is destroyed.

        Does nothing unless ``use_shared_memory`` is set. Attributes are read
        defensively because ``__del__`` also runs on instances whose
        ``__init__`` failed early or that were restored from a pickle without
        a logger; the logger may also legitimately be ``None``.
        """
        if not getattr(self, "use_shared_memory", False):
            return

        logger = getattr(self, "logger", None)
        if logger is not None:
            logger.info("Deleting GT database from shared memory")
        cur_rank, num_gpus = common_utils.get_dist_info()
        sa_key = self.sampler_cfg.DB_DATA_PATH[0]
        # Only one rank per group deletes the array, and only if it still exists.
        if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
            SharedArray.delete(f"shm://{sa_key}")

        if num_gpus > 1:
            dist.barrier()
        if logger is not None:
            logger.info("GT database has been removed from shared memory")

    def load_db_to_shared_memory(self):
        """Publish the GT database array under a shared-memory key.

        The array at ``root_path / DB_DATA_PATH[0]`` is loaded with ``np.load``
        and handed to ``common_utils.sa_create`` by one rank per machine, and
        only when ``/dev/shm/<key>`` does not exist yet. Logging is skipped
        when no logger was supplied.

        Returns:
            The shared-memory key, which is ``DB_DATA_PATH[0]``.

        Raises:
            AssertionError: if more than one ``DB_DATA_PATH`` entry is configured.
        """
        if self.logger is not None:
            self.logger.info("Loading GT database to shared memory")
        cur_rank, world_size, num_gpus = common_utils.get_dist_info(return_gpu_per_machine=True)

        assert len(self.sampler_cfg.DB_DATA_PATH) == 1, "Current only support single DB_DATA"
        db_data_path = self.root_path.resolve() / self.sampler_cfg.DB_DATA_PATH[0]
        sa_key = self.sampler_cfg.DB_DATA_PATH[0]

        if cur_rank % num_gpus == 0 and not os.path.exists(f"/dev/shm/{sa_key}"):
            gt_database_data = np.load(db_data_path)
            common_utils.sa_create(f"shm://{sa_key}", gt_database_data)

        # Wait until the array has been created before any rank continues.
        if num_gpus > 1:
            dist.barrier()
        if self.logger is not None:
            self.logger.info("GT database has been saved to shared memory")
        return sa_key

    def filter_by_difficulty(self, db_infos, removed_difficulty):
        """Drop database entries whose ``"difficulty"`` is in ``removed_difficulty``.

        Args:
            db_infos: mapping from class name to a list of info dicts.
            removed_difficulty: container of difficulty values to discard.

        Returns:
            A new mapping with the same keys and the surviving entries; the
            input mapping is left untouched.
        """
        kept_infos = {}
        for class_key, entries in db_infos.items():
            survivors = []
            for entry in entries:
                if entry["difficulty"] not in removed_difficulty:
                    survivors.append(entry)
            kept_infos[class_key] = survivors

            if self.logger is None:
                continue
            self.logger.info(
                "Database filter by difficulty %s: %d => %d"
                % (class_key, len(entries), len(survivors))
            )
        return kept_infos

    def filter_by_min_points(self, db_infos, min_gt_points_list):
        """Drop entries with fewer points than the per-class minimum.

        Args:
            db_infos: mapping from class name to a list of info dicts; it is
                updated in place.
            min_gt_points_list: strings of the form ``"<class>:<min_points>"``.
                Entries with a non-positive minimum or an unknown class are
                ignored.

        Returns:
            The same ``db_infos`` mapping.
        """
        for spec in min_gt_points_list:
            name, threshold = spec.split(":")
            threshold = int(threshold)
            if threshold <= 0 or name not in db_infos:
                continue

            kept = [info for info in db_infos[name] if info["num_points_in_gt"] >= threshold]

            if self.logger is not None:
                self.logger.info(
                    "Database filter by min points %s: %d => %d"
                    % (name, len(db_infos[name]), len(kept))
                )
            db_infos[name] = kept

        return db_infos

    def sample_with_fixed_number(self, class_name, sample_group):
        """
        Args:
            class_name:
            sample_group:
        Returns:

        """
        sample_num, pointer, indices = (
            int(sample_group["sample_num"]),
            sample_group["pointer"],
            sample_group["indices"],
        )
        if pointer >= len(self.db_infos[class_name]):
            indices = np.random.permutation(len(self.db_infos[class_name]))
            pointer = 0

        sampled_dict = [
            self.db_infos[class_name][idx] for idx in indices[pointer : pointer + sample_num]
        ]
        pointer += sample_num
        sample_group["pointer"] = pointer
        sample_group["indices"] = indices
        return sampled_dict

    @staticmethod
    def put_boxes_on_road_planes(gt_boxes, road_planes, calib):
        """
        Only validate in KITTIDataset
        Args:
            gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
            road_planes: [a, b, c, d]
            calib:

        Returns:
        """
        a, b, c, d = road_planes
        center_cam = calib.lidar_to_rect(gt_boxes[:, 0:3])
        cur_height_cam = (-d - a * center_cam[:, 0] - c * center_cam[:, 2]) / b
        center_cam[:, 1] = cur_height_cam
        cur_lidar_height = calib.rect_to_lidar(center_cam)[:, 2]
        mv_height = gt_boxes[:, 2] - gt_boxes[:, 5] / 2 - cur_lidar_height
        gt_boxes[:, 2] -= mv_height  # lidar view
        return gt_boxes, mv_height

    def copy_paste_to_image_kitti(self, data_dict, crop_feat, gt_number, point_idxes=None):
        """Paste 2D crops into the image and drop points hidden behind pasted crops.

        Args:
            data_dict: sample dict; reads "images", "gt_boxes", "gt_boxes2d",
                "calib" and "points", and is updated in place ("images",
                "points", "points_2d").
            crop_feat: image crops, indexed like ``data_dict["gt_boxes2d"]``.
            gt_number: number of leading boxes that belong to the original
                scene; boxes from that index on are sampled ones.
            point_idxes: per-point index of the sampled object the point
                belongs to (compared against the box index minus
                ``gt_number``). It is required despite the ``None`` default.

        Returns:
            The updated ``data_dict``.
        """
        # Fixed augmentation mode; the branches below are selected by substring.
        kitti_img_aug_type = "by_depth"
        kitti_img_aug_use_type = "annotation"

        image = data_dict["images"]
        boxes3d = data_dict["gt_boxes"]
        boxes2d = data_dict["gt_boxes2d"]
        corners_lidar = box_utils.boxes_to_corners_3d(boxes3d)
        if "depth" in kitti_img_aug_type:
            # Descending x: boxes with larger x are pasted first, so boxes with
            # smaller x overwrite them (x is presumably the depth axis).
            paste_order = boxes3d[:, 0].argsort()[::-1]
        else:
            paste_order = np.arange(len(boxes3d), dtype=int)

        if "reverse" in kitti_img_aug_type:
            paste_order = paste_order[::-1]

        # np.int / np.float were removed in NumPy 1.24; the builtins are the
        # exact types those aliases referred to.
        paste_mask = -255 * np.ones(image.shape[:2], dtype=int)
        fg_mask = np.zeros(image.shape[:2], dtype=int)
        overlap_mask = np.zeros(image.shape[:2], dtype=int)
        depth_mask = np.zeros((*image.shape[:2], 2), dtype=float)
        points_2d, depth_2d = data_dict["calib"].lidar_to_img(data_dict["points"][:, :3])
        points_2d[:, 0] = np.clip(points_2d[:, 0], a_min=0, a_max=image.shape[1] - 1)
        points_2d[:, 1] = np.clip(points_2d[:, 1], a_min=0, a_max=image.shape[0] - 1)
        points_2d = points_2d.astype(int)
        for _order in paste_order:
            _box2d = boxes2d[_order]
            rows = slice(_box2d[1], _box2d[3])
            cols = slice(_box2d[0], _box2d[2])
            image[rows, cols] = crop_feat[_order]
            # Count pixels that had already been covered by an earlier paste.
            overlap_mask[rows, cols] += (paste_mask[rows, cols] > 0).astype(int)
            # Remember which box owns each pixel after this paste.
            paste_mask[rows, cols] = _order

            if "cover" in kitti_img_aug_use_type:
                # HxWx2 for min and max depth of each box region
                depth_mask[rows, cols, 0] = corners_lidar[_order, :, 0].min()
                depth_mask[rows, cols, 1] = corners_lidar[_order, :, 0].max()

            # foreground area of original point cloud in image plane
            if _order < gt_number:
                fg_mask[rows, cols] = 1

        data_dict["images"] = image

        px_rows, px_cols = points_2d[:, 1], points_2d[:, 0]
        # Points whose pixel is owned by the sampled box they belong to.
        new_mask = paste_mask[px_rows, px_cols] == (point_idxes + gt_number)
        # Points on pixels still owned by an original box, or on untouched background.
        raw_fg = (fg_mask == 1) & (paste_mask >= 0) & (paste_mask < gt_number)
        raw_bg = (fg_mask == 0) & (paste_mask < 0)
        raw_mask = raw_fg[px_rows, px_cols] | raw_bg[px_rows, px_cols]
        keep_mask = new_mask | raw_mask
        data_dict["points_2d"] = points_2d

        if "annotation" in kitti_img_aug_use_type:
            data_dict["points"] = data_dict["points"][keep_mask]
            data_dict["points_2d"] = data_dict["points_2d"][keep_mask]
        elif "projection" in kitti_img_aug_use_type:
            overlap_mask[overlap_mask >= 1] = 1
            data_dict["overlap_mask"] = overlap_mask
            if "cover" in kitti_img_aug_use_type:
                data_dict["depth_mask"] = depth_mask

        return data_dict

    def collect_image_crops_kitti(
        self, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
    ):
        """Re-project one sampled object into the current frame and cut its image crop.

        Args:
            info: database info dict of the sampled object; reads "image_idx"
                and "bbox".
            data_dict: current sample; reads "calib" and "images".
            obj_points: points of the sampled object; the first three columns
                are overwritten in place.
            sampled_gt_boxes: 3D boxes of all sampled objects; row ``idx`` is
                overwritten in place.
            sampled_gt_boxes2d: 2D boxes of all sampled objects; row ``idx`` is
                overwritten in place.
            idx: position of this object among the sampled objects.

        Returns:
            (new_box, img_crop2d, obj_points, obj_idx): the integer 2D paste
            box, the crop from the source image divided by 255, the
            re-projected points, and an array holding ``idx`` for every point.
        """
        calib_file = kitti_common.get_calib_path(
            int(info["image_idx"]), self.root_path, relative_path=False
        )
        sampled_calib = calibration_kitti.Calibration(calib_file)
        points_2d, depth_2d = sampled_calib.lidar_to_img(obj_points[:, :3])

        # align calibration metrics for points: project with the source frame's
        # calibration, then lift back with the current frame's calibration.
        points_ract = data_dict["calib"].img_to_rect(points_2d[:, 0], points_2d[:, 1], depth_2d)
        points_lidar = data_dict["calib"].rect_to_lidar(points_ract)
        obj_points[:, :3] = points_lidar
        # align calibration metrics for boxes
        box3d_raw = sampled_gt_boxes[idx].reshape(1, -1)
        box3d_coords = box_utils.boxes_to_corners_3d(box3d_raw)[0]
        box3d_box, box3d_depth = sampled_calib.lidar_to_img(box3d_coords)
        box3d_coord_rect = data_dict["calib"].img_to_rect(
            box3d_box[:, 0], box3d_box[:, 1], box3d_depth
        )
        box3d_rect = box_utils.corners_rect_to_camera(box3d_coord_rect).reshape(1, -1)
        box3d_lidar = box_utils.boxes3d_kitti_camera_to_lidar(box3d_rect, data_dict["calib"])
        box2d = box_utils.boxes3d_kitti_camera_to_imageboxes(
            box3d_rect, data_dict["calib"], data_dict["images"].shape[:2]
        )
        sampled_gt_boxes[idx] = box3d_lidar[0]
        sampled_gt_boxes2d[idx] = box2d[0]

        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        obj_idx = idx * np.ones(len(obj_points), dtype=int)

        # copy crops from images
        img_path = self.root_path / f'training/image_2/{info["image_idx"]}.png'
        raw_image = io.imread(img_path)
        raw_image = raw_image.astype(np.float32)
        raw_center = info["bbox"].reshape(2, 2).mean(0)
        new_box = sampled_gt_boxes2d[idx].astype(int)
        new_shape = np.array([new_box[2] - new_box[0], new_box[3] - new_box[1]])
        # Source window of the target size, centred on the object's original bbox.
        raw_box = np.concatenate([raw_center - new_shape / 2, raw_center + new_shape / 2]).astype(
            int
        )
        raw_box[0::2] = np.clip(raw_box[0::2], a_min=0, a_max=raw_image.shape[1])
        raw_box[1::2] = np.clip(raw_box[1::2], a_min=0, a_max=raw_image.shape[0])
        if (raw_box[2] - raw_box[0]) != new_shape[0] or (raw_box[3] - raw_box[1]) != new_shape[1]:
            # Clipping shrank the source window: shrink the paste box to match,
            # keeping its centre.
            new_center = new_box.reshape(2, 2).mean(0)
            new_shape = np.array([raw_box[2] - raw_box[0], raw_box[3] - raw_box[1]])
            new_box = np.concatenate(
                [new_center - new_shape / 2, new_center + new_shape / 2]
            ).astype(int)

        img_crop2d = raw_image[raw_box[1] : raw_box[3], raw_box[0] : raw_box[2]] / 255

        return new_box, img_crop2d, obj_points, obj_idx

    def sample_gt_boxes_2d_kitti(self, data_dict, sampled_boxes, valid_mask):
        """Reject sampled boxes whose image projection overlaps other boxes too much.

        Args:
            data_dict: current sample; reads "calib", "images", "gt_boxes2d"
                and, when USE_ROAD_PLANE is enabled, "road_plane".
            sampled_boxes: candidate 3D boxes; passed to
                ``put_boxes_on_road_planes`` when USE_ROAD_PLANE is enabled.
            valid_mask: boolean mask of candidates that are still valid.

        Returns:
            (sampled_boxes2d, mv_height, ret_valid_mask): the 2D boxes and
            height offsets of the surviving candidates (``mv_height`` is
            ``None`` without road planes) and the combined validity mask over
            all candidates.
        """
        mv_height = None
        if self.sampler_cfg.get("USE_ROAD_PLANE", False):
            sampled_boxes, mv_height = self.put_boxes_on_road_planes(
                sampled_boxes, data_dict["road_plane"], data_dict["calib"]
            )

        calib = data_dict["calib"]
        boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera(sampled_boxes, calib)
        image_hw = data_dict["images"].shape[:2]
        candidate_2d = torch.Tensor(
            box_utils.boxes3d_kitti_camera_to_imageboxes(boxes_camera, calib, image_hw)
        )
        scene_2d = torch.Tensor(data_dict["gt_boxes2d"])

        iou_vs_scene = box_utils.pairwise_iou(candidate_2d, scene_2d).cpu().numpy()
        iou_vs_self = box_utils.pairwise_iou(candidate_2d, candidate_2d).cpu().numpy()
        # A box always overlaps itself; ignore the diagonal.
        diag = range(candidate_2d.shape[0])
        iou_vs_self[diag, diag] = 0
        # With no scene boxes there is nothing to take a max over; reuse the
        # candidate-vs-candidate matrix.
        if iou_vs_scene.shape[1] == 0:
            iou_vs_scene = iou_vs_self

        threshold = self.img_aug_iou_thresh
        clear_of_scene = iou_vs_scene.max(axis=1) < threshold
        clear_of_others = iou_vs_self.max(axis=1) < threshold
        ret_valid_mask = clear_of_scene & clear_of_others & valid_mask

        kept_boxes2d = candidate_2d[ret_valid_mask].cpu().numpy()
        if mv_height is not None:
            mv_height = mv_height[ret_valid_mask]
        return kept_boxes2d, mv_height, ret_valid_mask

    def sample_gt_boxes_2d(self, data_dict, sampled_boxes, valid_mask):
        """Dispatch the 2D overlap filtering to the configured image augmentation.

        Returns:
            (sampled_boxes2d, mv_height, ret_valid_mask) as produced by
            ``sample_gt_boxes_2d_kitti``.

        Raises:
            NotImplementedError: if ``img_aug_type`` is not "kitti".
        """
        if self.img_aug_type != "kitti":
            raise NotImplementedError

        boxes2d, mv_height, ret_valid_mask = self.sample_gt_boxes_2d_kitti(
            data_dict, sampled_boxes, valid_mask
        )
        return boxes2d, mv_height, ret_valid_mask

    def initilize_image_aug_dict(self, data_dict, gt_boxes_mask):
        img_aug_gt_dict = None
        if self.img_aug_type is None:
            pass
        elif self.img_aug_type == "kitti":
            obj_index_list, crop_boxes2d = [], []
            gt_number = gt_boxes_mask.sum().astype(np.int)
            gt_boxes2d = data_dict["gt_boxes2d"][gt_boxes_mask].astype(np.int)
            gt_crops2d = [data_dict["images"][_x[1] : _x[3], _x[0] : _x[2]] for _x in gt_boxes2d]

            img_aug_gt_dict = {
                "obj_index_list": obj_index_list,
                "gt_crops2d": gt_crops2d,
                "gt_boxes2d": gt_boxes2d,
                "gt_number": gt_number,
                "crop_boxes2d": crop_boxes2d,
            }
        else:
            raise NotImplementedError

        return img_aug_gt_dict

    def collect_image_crops(
        self,
        img_aug_gt_dict,
        info,
        data_dict,
        obj_points,
        sampled_gt_boxes,
        sampled_gt_boxes2d,
        idx,
    ):
        """Collect the image crop of one sampled object into ``img_aug_gt_dict``.

        The paste box, the crop and the per-point object index returned by
        ``collect_image_crops_kitti`` are appended to the "crop_boxes2d",
        "gt_crops2d" and "obj_index_list" lists.

        Returns:
            (img_aug_gt_dict, obj_points) with the updated dict and the points
            returned by the KITTI helper.

        Raises:
            NotImplementedError: if ``img_aug_type`` is not "kitti".
        """
        if self.img_aug_type != "kitti":
            raise NotImplementedError

        paste_box, crop, obj_points, point_obj_idx = self.collect_image_crops_kitti(
            info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
        )
        img_aug_gt_dict["crop_boxes2d"].append(paste_box)
        img_aug_gt_dict["gt_crops2d"].append(crop)
        img_aug_gt_dict["obj_index_list"].append(point_obj_idx)

        return img_aug_gt_dict, obj_points

    def copy_paste_to_image(self, img_aug_gt_dict, data_dict, points):
        if self.img_aug_type == "kitti":
            obj_points_idx = np.concatenate(img_aug_gt_dict["obj_index_list"], axis=0)
            point_idxes = -1 * np.ones(len(points), dtype=np.int)
            point_idxes[: obj_points_idx.shape[0]] = obj_points_idx

            data_dict["gt_boxes2d"] = np.concatenate(
                [img_aug_gt_dict["gt_boxes2d"], np.array(img_aug_gt_dict["crop_boxes2d"])], axis=0
            )
            data_dict = self.copy_paste_to_image_kitti(
                data_dict, img_aug_gt_dict["gt_crops2d"], img_aug_gt_dict["gt_number"], point_idxes
            )
            if "road_plane" in data_dict:
                data_dict.pop("road_plane")
        else:
            raise NotImplementedError
        return data_dict

    def add_sampled_boxes_to_scene(
        self,
        data_dict,
        sampled_gt_boxes,
        total_valid_sampled_dict,
        mv_height=None,
        sampled_gt_boxes2d=None,
    ):
        """Insert the sampled objects (boxes, names and points) into the scene.

        Args:
            data_dict: current sample; reads "gt_boxes_mask", "gt_boxes",
                "gt_names" and "points" (plus "road_plane" / "calib" when road
                planes are used) and is updated in place.
            sampled_gt_boxes: boxes of the sampled objects, one row per entry
                of ``total_valid_sampled_dict``.
            total_valid_sampled_dict: database info dicts of the sampled
                objects.
            mv_height: per-object height offsets; when ``None`` and
                USE_ROAD_PLANE is enabled they are computed here.
            sampled_gt_boxes2d: 2D boxes of the sampled objects, only used
                when image augmentation is enabled.

        Returns:
            ``data_dict`` with "gt_boxes", "gt_names" and "points" replaced by
            the masked scene content plus the sampled objects.
        """
        gt_boxes_mask = data_dict["gt_boxes_mask"]
        gt_boxes = data_dict["gt_boxes"][gt_boxes_mask]
        gt_names = data_dict["gt_names"][gt_boxes_mask]
        points = data_dict["points"]
        if self.sampler_cfg.get("USE_ROAD_PLANE", False) and mv_height is None:
            sampled_gt_boxes, mv_height = self.put_boxes_on_road_planes(
                sampled_gt_boxes, data_dict["road_plane"], data_dict["calib"]
            )
            # Both entries are removed from the sample once the offsets are known.
            data_dict.pop("calib")
            data_dict.pop("road_plane")

        obj_points_list = []

        # convert sampled 3D boxes to image plane
        img_aug_gt_dict = self.initilize_image_aug_dict(data_dict, gt_boxes_mask)

        if self.use_shared_memory:
            gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
            # Read-only view: object points are copied out before being modified.
            gt_database_data.setflags(write=0)
        else:
            gt_database_data = None

        for idx, info in enumerate(total_valid_sampled_dict):
            if self.use_shared_memory:
                start_offset, end_offset = info["global_data_offset"]
                obj_points = copy.deepcopy(gt_database_data[start_offset:end_offset])
            else:
                file_path = self.root_path / info["path"]

                obj_points = np.fromfile(str(file_path), dtype=np.float32).reshape(
                    [-1, self.sampler_cfg.NUM_POINT_FEATURES]
                )
                # If the float32 read yields the wrong point count, retry as float64.
                if obj_points.shape[0] != info["num_points_in_gt"]:
                    obj_points = np.fromfile(str(file_path), dtype=np.float64).reshape(
                        -1, self.sampler_cfg.NUM_POINT_FEATURES
                    )

            assert obj_points.shape[0] == info["num_points_in_gt"]
            # Shift the stored points by the box centre from the database info
            # (presumably the points are stored relative to it).
            obj_points[:, :3] += info["box3d_lidar"][:3].astype(np.float32)

            if self.sampler_cfg.get("USE_ROAD_PLANE", False):
                # mv height
                obj_points[:, 2] -= mv_height[idx]

            if self.img_aug_type is not None:
                # The helper appends to the crop lists and returns the (possibly
                # replaced) object points.
                img_aug_gt_dict, obj_points = self.collect_image_crops(
                    img_aug_gt_dict,
                    info,
                    data_dict,
                    obj_points,
                    sampled_gt_boxes,
                    sampled_gt_boxes2d,
                    idx,
                )

            obj_points_list.append(obj_points)

        obj_points = np.concatenate(obj_points_list, axis=0)
        sampled_gt_names = np.array([x["name"] for x in total_valid_sampled_dict])

        # Filter object points on their last column when explicitly configured,
        # or when the object points carry a different number of columns than
        # the scene points.
        if (
            self.sampler_cfg.get("FILTER_OBJ_POINTS_BY_TIMESTAMP", False)
            or obj_points.shape[-1] != points.shape[-1]
        ):
            if self.sampler_cfg.get("FILTER_OBJ_POINTS_BY_TIMESTAMP", False):
                min_time = min(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
                max_time = max(self.sampler_cfg.TIME_RANGE[0], self.sampler_cfg.TIME_RANGE[1])
            else:
                assert obj_points.shape[-1] == points.shape[-1] + 1
                # transform multi-frame GT points to single-frame GT points
                min_time = max_time = 0.0

            # Keep points whose last column lies in [min_time, max_time], with
            # a 1e-6 tolerance on both ends.
            time_mask = np.logical_and(
                obj_points[:, -1] < max_time + 1e-6, obj_points[:, -1] > min_time - 1e-6
            )
            obj_points = obj_points[time_mask]

        # Remove scene points inside the enlarged sampled boxes before adding
        # the object points.
        large_sampled_gt_boxes = box_utils.enlarge_box3d(
            sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH
        )
        points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes)
        # Object points come first and are truncated to the scene's column count.
        points = np.concatenate([obj_points[:, : points.shape[-1]], points], axis=0)
        gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0)
        gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0)
        data_dict["gt_boxes"] = gt_boxes
        data_dict["gt_names"] = gt_names
        data_dict["points"] = points

        if self.img_aug_type is not None:
            data_dict = self.copy_paste_to_image(img_aug_gt_dict, data_dict, points)

        return data_dict

    def __call__(self, data_dict):
        """
        Args:
            data_dict:
                gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]

        Returns:

        """
        gt_boxes = data_dict["gt_boxes"]
        gt_names = data_dict["gt_names"].astype(str)
        existed_boxes = gt_boxes
        total_valid_sampled_dict = []
        sampled_mv_height = []
        sampled_gt_boxes2d = []

        for class_name, sample_group in self.sample_groups.items():
            if self.limit_whole_scene:
                num_gt = np.sum(class_name == gt_names)
                sample_group["sample_num"] = str(int(self.sample_class_num[class_name]) - num_gt)
            if int(sample_group["sample_num"]) > 0:
                sampled_dict = self.sample_with_fixed_number(class_name, sample_group)

                sampled_boxes = np.stack([x["box3d_lidar"] for x in sampled_dict], axis=0).astype(
                    np.float32
                )

                assert not self.sampler_cfg.get(
                    "DATABASE_WITH_FAKELIDAR", False
                ), "Please use latest codes to generate GT_DATABASE"

                iou1 = iou3d_nms_utils.boxes_bev_iou_cpu(
                    sampled_boxes[:, 0:7], existed_boxes[:, 0:7]
                )
                iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(
                    sampled_boxes[:, 0:7], sampled_boxes[:, 0:7]
                )
                iou2[range(sampled_boxes.shape[0]), range(sampled_boxes.shape[0])] = 0
                iou1 = iou1 if iou1.shape[1] > 0 else iou2
                valid_mask = (iou1.max(axis=1) + iou2.max(axis=1)) == 0

                if self.img_aug_type is not None:
                    sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d(
                        data_dict, sampled_boxes, valid_mask
                    )
                    sampled_gt_boxes2d.append(sampled_boxes2d)
                    if mv_height is not None:
                        sampled_mv_height.append(mv_height)

                valid_mask = valid_mask.nonzero()[0]
                valid_sampled_dict = [sampled_dict[x] for x in valid_mask]
                valid_sampled_boxes = sampled_boxes[valid_mask]

                existed_boxes = np.concatenate(
                    (existed_boxes, valid_sampled_boxes[:, : existed_boxes.shape[-1]]), axis=0
                )
                total_valid_sampled_dict.extend(valid_sampled_dict)

        sampled_gt_boxes = existed_boxes[gt_boxes.shape[0] :, :]

        if total_valid_sampled_dict.__len__() > 0:
            sampled_gt_boxes2d = (
                np.concatenate(sampled_gt_boxes2d, axis=0) if len(sampled_gt_boxes2d) > 0 else None
            )
            sampled_mv_height = (
                np.concatenate(sampled_mv_height, axis=0) if len(sampled_mv_height) > 0 else None
            )

            data_dict = self.add_sampled_boxes_to_scene(
                data_dict,
                sampled_gt_boxes,
                total_valid_sampled_dict,
                sampled_mv_height,
                sampled_gt_boxes2d,
            )

        data_dict.pop("gt_boxes_mask")
        return data_dict
